import paddle
import numpy as np

class SSPred():
    """
    Post-processing for a multi-scale detection head.

    Decodes the raw feature maps of every output level into boxes and
    per-class scores, then applies score thresholding and per-class
    non-maximum suppression for every image of the batch.
    """

    def __init__(self,
                 num_classes=20,
                 out_strides=(32, 16, 8),
                 sco_threshs=0.70,
                 nms_threshs=0.45):
        """
        Store the decoding configuration.
        * params:
        - num_classes: number of object classes (stored only; the decoding
                       below takes the class count from the feature maps)
        - out_strides: down-sampling rate of each output level, paired
                       one-to-one with the entries of `p_list`
        - sco_threshs: a candidate is kept only if its score is strictly
                       greater than this value
        - nms_threshs: a box is suppressed when its IoU with an already
                       kept box is greater than this value
        """
        super().__init__()

        self.num_classes = num_classes
        # Copy into a fresh list so instances never share (or mutate) the
        # default sequence.
        self.out_strides = list(out_strides)
        self.sco_threshs = sco_threshs
        self.nms_threshs = nms_threshs

    def __call__(self, p_list, imghws):
        """
        Compute the final detections.
        * params:
        - p_list: list of predicted feature maps, one per output level,
                  each of shape [b,5+c,h,w]
        - imghws: image heights and widths, shape [b,2] as (height, width)
        * return:
        - infers: list of length b; entry k is an array of shape [n,6] with
                  rows [class,score,x1,y1,x2,y2], or an empty list when
                  image k has no detection
        """
        # Decode boxes and scores of every output level
        boxes, score = self.get_boxes_score(p_list, imghws)

        # Threshold and suppress per class
        return self.multi_class_nms(boxes, score)

#########################################################################################

    def get_boxes_score(self, p_list, imghws):
        """
        Decode boxes and scores of all output levels.
        params:
        - p_list: list of predicted feature maps, each of shape [b,5+c,h,w]
                  with channels [tx,ty,tw,th,objectness,class_0..class_c-1]
        - imghws: image heights and widths, shape [b,2]
        return:
        - boxes : decoded boxes of all levels, shape [b,4,m]
        - score : decoded scores of all levels, shape [b,c,m]
        """
        pdbox_list = []  # decoded boxes per level
        pdsco_list = []  # decoded scores per level

        # Decoding is inference-only, so gradient tracking is disabled.
        with paddle.no_grad():
            for infer, stride in zip(p_list, self.out_strides):
                # Boxes from the first four channels, shape [b,4,h*w]
                pdloc = infer[:, :4, :, :]
                pdbox_list.append(self.get_pdbox(pdloc, imghws, stride))

                # Scores from objectness and class channels, shape [b,c,h*w]
                pdobj = infer[:, 4:5, :, :]
                pdcls = infer[:, 5:, :, :]
                pdsco_list.append(self.get_pdsco(pdobj, pdcls))

            # Merge all levels along the last (candidate) axis
            boxes = paddle.concat(pdbox_list, axis=-1)  # [b,4,m]
            score = paddle.concat(pdsco_list, axis=-1)  # [b,c,m]

        return boxes, score

    def get_pdbox(self, pdloc, imghws, stride):
        """
        Decode the location channels of one output level into corner boxes.
        * params:
        - pdloc : predicted location features [tx,ty,tw,th], shape [b,4,h,w]
        - imghws: image heights and widths, shape [b,2] as (height, width)
        - stride: down-sampling rate of this output level
        * return:
        - pdbox : boxes [x1,y1,x2,y2] scaled by `imghws`, shape [b,4,h*w]
        """
        sigmoid = paddle.nn.functional.sigmoid

        # Grid cell coordinates
        grid_h = pdloc.shape[-2]
        grid_w = pdloc.shape[-1]

        yv, xv = paddle.meshgrid([paddle.arange(grid_h), paddle.arange(grid_w)])
        grid_x = xv.astype('float32').unsqueeze(0)  # column index, shape [1,h,w]
        grid_y = yv.astype('float32').unsqueeze(0)  # row index, shape [1,h,w]

        # Size of the network input implied by this level
        imgs_h = grid_h * stride
        imgs_w = grid_w * stride

        # Normalised centre and size, each of shape [b,h,w]
        pd_px = (grid_x + sigmoid(pdloc[:, 0, :, :])) / grid_w  # (cx + sigmoid(tx)) / grid_w
        pd_py = (grid_y + sigmoid(pdloc[:, 1, :, :])) / grid_h  # (cy + sigmoid(ty)) / grid_h
        pd_pw = paddle.exp(pdloc[:, 2, :, :]) / imgs_w          # exp(tw) / input width
        pd_ph = paddle.exp(pdloc[:, 3, :, :]) / imgs_h          # exp(th) / input height

        # Centre/size -> corner format
        x1 = pd_px - 0.5 * pd_pw
        y1 = pd_py - 0.5 * pd_ph
        x2 = x1 + pd_pw
        y2 = y1 + pd_ph

        # Scale the normalised corners by the per-image height and width.
        # The sizes are cast to the box dtype so the result keeps that dtype
        # even when `imghws` is stored with a different one.
        imghw = imghws.unsqueeze([2, 3]).astype(pd_px.dtype)  # [b,2,1,1]
        img_h = imghw[:, 0, :, :]                             # [b,1,1]
        img_w = imghw[:, 1, :, :]                             # [b,1,1]

        pdbox = paddle.stack([x1 * img_w, y1 * img_h, x2 * img_w, y2 * img_h],
                             axis=1)  # [b,4,h,w]

        # Flatten the spatial dimensions
        b, n, h, w = pdbox.shape
        return pdbox.reshape([b, n, h * w])  # [b,4,h*w]

    def get_pdsco(self, pdobj, pdcls):
        """
        Compute the per-class scores of one output level.
        * params:
        - pdobj: predicted objectness features, shape [b,1,h,w]
        - pdcls: predicted class features, shape [b,c,h,w]
        * return:
        - pdsco: sigmoid(objectness) * sigmoid(class), shape [b,c,h*w]
        """
        # Score is the product of the two sigmoid activations
        pdobj = paddle.nn.functional.sigmoid(pdobj)  # [b,1,h,w]
        pdcls = paddle.nn.functional.sigmoid(pdcls)  # [b,c,h,w]
        pdsco = pdobj * pdcls                        # [b,c,h,w]

        # Flatten the spatial dimensions
        b, c, h, w = pdsco.shape
        return pdsco.reshape([b, c, h * w])  # [b,c,h*w]

#########################################################################################

    def multi_class_nms(self, boxes, score):
        """
        Score thresholding followed by per-class non-maximum suppression.
        * params:
        - boxes : predicted boxes, shape [b,4,n]; must provide `.numpy()`
        - score : predicted scores, shape [b,c,n]; must provide `.numpy()`
        * return:
        - infers: list of length b; entry k is an array of shape [n,6] with
                  rows [class,score,x1,y1,x2,y2], or an empty list when
                  image k has no detection
        """
        boxes = boxes.numpy()  # [b,4,n]
        score = score.numpy()  # [b,c,n]

        infers = []
        for image_boxes, image_score in zip(boxes, score):
            # One row per candidate, so a boolean mask selects rows directly
            candidates = image_boxes.T  # [n,4]

            infer_list = []  # detections of this image, one array per class
            for i, class_score in enumerate(image_score):
                # Score thresholding
                mask = class_score > self.sco_threshs
                if not mask.any():
                    continue

                # Non-maximum suppression on the surviving candidates
                mask_boxes = candidates[mask]   # [m,4]
                mask_score = class_score[mask]  # [m]
                keep = self.nms(mask_boxes, mask_score)
                if not keep:
                    continue

                # Assemble rows [class,score,x1,y1,x2,y2]
                keep_class = np.full((len(keep), 1), float(i))  # [k,1]
                keep_score = mask_score[keep, np.newaxis]       # [k,1]
                keep_boxes = mask_boxes[keep, :]                # [k,4]
                infer_list.append(
                    np.concatenate([keep_class, keep_score, keep_boxes], axis=-1))  # [k,6]

            if infer_list:
                infers.append(np.concatenate(infer_list, axis=0))  # [n,6]
            else:
                # An image without detections is represented by an empty list
                infers.append(infer_list)

        return infers

    def nms(self, boxes, score):
        """
        Greedy single-class non-maximum suppression.
        * params:
        - boxes: candidate boxes [x1,y1,x2,y2], shape [m,4]
        - score: candidate scores, shape [m]
        * return:
        - keep : list of indices into `boxes` of the kept candidates,
                 ordered by descending score
        """
        x1 = boxes[:, 0]
        y1 = boxes[:, 1]
        x2 = boxes[:, 2]
        y2 = boxes[:, 3]

        # Areas use the inclusive "+1" convention, matching the intersection below
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = score.argsort()[::-1]  # candidate indices, best score first

        keep = []
        while order.size > 0:
            # The best remaining candidate is always kept
            i = order[0]
            keep.append(i)
            rest = order[1:]

            # Intersection rectangle: the larger of the top-left corners and
            # the smaller of the bottom-right corners
            xx1 = np.maximum(x1[i], x1[rest])
            yy1 = np.maximum(y1[i], y1[rest])
            xx2 = np.minimum(x2[i], x2[rest])
            yy2 = np.minimum(y2[i], y2[rest])

            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h

            # Intersection over union with the kept candidate
            ovr = inter / (areas[i] + areas[rest] - inter)

            # Continue with the candidates that do not overlap too much;
            # +1 maps positions in `rest` back to positions in `order`
            inds = np.where(ovr <= self.nms_threshs)[0]
            order = order[inds + 1]

        return keep